// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
// file except in compliance with the License. You may obtain a copy of the
// License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#include "flare/rpc/protocol/protobuf/qzone_protocol.h"

#include <optional>
#include <utility>

#include "flare/base/buffer/zero_copy_stream.h"
#include "flare/base/down_cast.h"
#include "flare/base/likely.h"
#include "flare/base/overloaded.h"
#include "flare/rpc/protocol/protobuf/call_context.h"
#include "flare/rpc/protocol/protobuf/call_context_factory.h"
#include "flare/rpc/protocol/protobuf/detail/qzone_header.h"
#include "flare/rpc/protocol/protobuf/detail/utility.h"
#include "flare/rpc/protocol/protobuf/message.h"
#include "flare/rpc/protocol/protobuf/rpc_meta.pb.h"
#include "flare/rpc/protocol/protobuf/rpc_options.h"
#include "flare/rpc/protocol/protobuf/service_method_locator.h"

namespace flare::protobuf {

// Register `QZoneProtocol` under the name "qzone-pb", once as a client-side
// stream protocol and once as a server-side one.
//
// NOTE(review): the trailing boolean presumably ends up as `server_side_` of
// the created `QZoneProtocol` instance -- confirm against the macro definition.
FLARE_RPC_REGISTER_CLIENT_SIDE_STREAM_PROTOCOL_ARG("qzone-pb", QZoneProtocol,
                                                   false);
FLARE_RPC_REGISTER_SERVER_SIDE_STREAM_PROTOCOL_ARG("qzone-pb", QZoneProtocol,
                                                   true);

namespace {

// Builds the (service id, method id) pair identifying `method` on the wire.
//
// Both IDs are looked up via `TryGetQZoneServiceId` / `TryGetQZoneMethodId`. If
// either of them is absent, a value-initialized pair (i.e., `{0, 0}`) is
// returned instead.
std::pair<std::int32_t, std::int32_t> TryGetQzoneMethodKey(
    const google::protobuf::MethodDescriptor* method) {
  const auto service_id = TryGetQZoneServiceId(method->service());
  const auto method_id = TryGetQZoneMethodId(method);
  if (!service_id || !method_id) {
    return {};
  }
  return std::pair(*service_id, *method_id);
}

// Makes `method` known to `ServiceMethodLocator` under the QZone protocol ID.
//
// Methods for which either half of the key is zero (which includes the case
// that `TryGetQzoneMethodKey` found no IDs at all) are silently skipped.
void RegisterMethodCallback(const google::protobuf::MethodDescriptor* method) {
  auto key = TryGetQzoneMethodKey(method);
  if (key.first == 0 || key.second == 0) {
    return;
  }
  ServiceMethodLocator::Instance()->RegisterMethod(protocol_ids::qzone, method,
                                                   key);
}

// Counterpart of `RegisterMethodCallback`: removes `method` from
// `ServiceMethodLocator` for the QZone protocol ID.
//
// The same filter as on registration applies, so methods whose key has a zero
// half are left alone.
void DeregisterMethodCallback(
    const google::protobuf::MethodDescriptor* method) {
  auto key = TryGetQzoneMethodKey(method);
  if (key.first == 0 || key.second == 0) {
    return;
  }
  ServiceMethodLocator::Instance()->DeregisterMethod(protocol_ids::qzone,
                                                     method);
}

}  // namespace

// Hand `RegisterMethodCallback` / `DeregisterMethodCallback` over to the
// Protocol Buffers method provider registry.
//
// NOTE(review): presumably the callbacks are invoked once per method as
// services get registered / deregistered -- confirm against the macro.
FLARE_RPC_PROTOCOL_PROTOBUF_REGISTER_METHOD_PROVIDER(RegisterMethodCallback,
                                                     DeregisterMethodCallback);

namespace {

// Shorthand for the method descriptor type `ServiceMethodLocator` keeps for the
// QZone protocol ID.
using MethodDesc = ServiceMethodLocator::MethodDesc<protocol_ids::QZone>;

// A QZone frame as cut off the wire by `TryCutMessage`, before `TryParse`
// turns it into something the upper layers understand.
struct OnWireMessage : Message {
  OnWireMessage() { SetRuntimeTypeTo<OnWireMessage>(); }

  // The correlation ID is the serial number carried in the QZone header.
  std::uint64_t GetCorrelationId() const noexcept override {
    return header.head.serialNo;
  }
  Type GetType() const noexcept override {
    return stream ? Type::Stream : Type::Single;
  }

  // Not used by client-side protocol (left as `nullptr`). For server-side, if
  // the requested version / cmd is not recognized, it's `nullptr`.
  //
  // Defaulted so that reading it is never an access to an indeterminate value.
  const ServiceMethodLocator::MethodDesc<protocol_ids::QZone>* method_desc =
      nullptr;

  // Frame body, with the header and the trailing EOT byte stripped.
  NoncontiguousBuffer payload;

  // Meaningless for client side as we cannot reliably detect it.
  bool stream = false;

  // This field must be the last one as `qzone::QzoneProtocol` uses flexible
  // array.
  qzone::QzoneProtocol header;
};

// Characteristics reported by `QZoneProtocol::GetCharacteristics()`.
//
// This object is only ever handed out by const reference, so it's declared
// `const` to make sure nobody in this file mutates it by accident.
const StreamProtocol::Characteristics characteristics = {
    .name = "QZone/Protobuf",
    .no_end_of_stream_marker = true,
    .ignore_message_type_for_client_side_streaming = true};

}  // namespace

// Returns the file-local `characteristics` object describing this protocol.
const StreamProtocol::Characteristics& QZoneProtocol::GetCharacteristics()
    const {
  return characteristics;
}

// Returns a pointer to `error_message_factory` (declared elsewhere).
const MessageFactory* QZoneProtocol::GetMessageFactory() const {
  return &error_message_factory;
}

// Returns a pointer to `passive_call_context_factory` (declared elsewhere).
const ControllerFactory* QZoneProtocol::GetControllerFactory() const {
  return &passive_call_context_factory;
}

// Tries to cut one QZone frame off the head of `buffer`.
//
// Returns:
//
// - `NotIdentified` if `buffer` does not hold a complete header yet.
// - `ProtocolMismatch` if the `soh` field is not `QzoneProtocolSOH`, or if the
//   length field is too small to cover the header plus the EOT byte.
// - `NeedMore` if the frame has not been fully received. `buffer` is left
//   untouched in this case.
// - `Cut` on success. The frame is removed from `buffer` and `*message` is set
//   to an `OnWireMessage` holding the decoded header and the payload.
//
// For server side, a version / cmd pair that is not registered is NOT an error
// here: the frame is still cut, with `method_desc` left `nullptr`, so that
// `TryParse` can answer the caller with `rpc::STATUS_METHOD_NOT_FOUND`.
StreamProtocol::MessageCutStatus QZoneProtocol::TryCutMessage(
    NoncontiguousBuffer* buffer, std::unique_ptr<Message>* message) {
  if (buffer->ByteSize() < sizeof(qzone::QzoneProtocol)) {
    return MessageCutStatus::NotIdentified;
  }

  qzone::QzoneProtocol header;
  FlattenToSlow(*buffer, &header, sizeof(header));
  if (header.soh != QzoneProtocolSOH) {
    return MessageCutStatus::ProtocolMismatch;
  }
  header.head.Decode();  // `ntohl` stuff.

  // `len` is supplied by the peer and covers header + payload + EOT. Anything
  // shorter than header + EOT cannot be a valid frame, and would otherwise make
  // `Skip` / `Cut` below run past the end of the bytes we cut.
  auto size = header.head.len;
  if (size < offsetof(qzone::QzoneProtocol, body) + 1 /* EOT */) {
    return MessageCutStatus::ProtocolMismatch;
  }
  // Check for completeness before doing any allocation or method lookup, they
  // would be wasted on every partial read otherwise.
  if (buffer->ByteSize() < size) {
    return MessageCutStatus::NeedMore;
  }

  // Try recognize the method being called, and see if the method being called
  // is a streaming one.
  //
  // FIXME: From my inspection, code generated by compiler for this branch is a
  // bit of awful, perhaps we should separate server-side & client-side protocol
  // into two classes.
  auto msg = std::make_unique<OnWireMessage>();
  if (server_side_) {
    // `header.head` is not properly aligned, using them to construct method key
    // annoys ASAN. So we copy them first.
    auto version = header.head.version;
    auto cmd = header.head.cmd;
    // This lookup hurts performance.
    auto desc = ServiceMethodLocator::Instance()->TryGetMethodDesc(
        protocol_ids::qzone, {version, cmd});
    if (desc) {
      FLARE_CHECK(
          !desc->method_desc->client_streaming(),
          "Client-side streaming is not implemented for QZone protocol.");
      msg->stream = desc->method_desc->server_streaming();
    } else {
      // Not recognized then. `method_desc` is left `nullptr` below, `TryParse`
      // generates the error response on seeing that.
      msg->stream = false;
    }
    msg->method_desc = desc;
  } else {
    // This is ignored by `StreamCallGate`.
    //
    // @sa: `ignore_message_type_for_client_side_streaming` in
    // `StreamProtocol::Characteristics`.
    msg->stream = false;
    // Not used by client side, but don't leave it indeterminate.
    msg->method_desc = nullptr;
  }

  auto cut = buffer->Cut(size);
  cut.Skip(offsetof(qzone::QzoneProtocol, body));  // Skip the header.

  // Destination is cast to `void*` explicitly as in the original code
  // (NOTE(review): presumably to keep the compiler quiet about `memcpy`-ing
  // into a class type -- confirm before simplifying).
  memcpy(reinterpret_cast<void*>(&msg->header),
         reinterpret_cast<const void*>(&header), sizeof(header));
  msg->payload = cut.Cut(cut.ByteSize() - 1);  // EOT byte.

  *message = std::move(msg);
  return MessageCutStatus::Cut;
}

// Similar to `StdProtocol::TryParse`.
//
// Turns the `OnWireMessage` in `*message` (as produced by `TryCutMessage`)
// into a `ProtoMessage`. For server side, if the requested method was not
// recognized, `*message` is replaced by an `EarlyErrorMessage` carrying
// `rpc::STATUS_METHOD_NOT_FOUND` instead.
//
// `controller` is only consulted for client side, where it's cast to
// `ProactiveCallContext`.
//
// Returns `false` only if the payload cannot be parsed into the expected
// Protocol Buffers message; `*message` is not replaced in that case.
bool QZoneProtocol::TryParse(std::unique_ptr<Message>* message,
                             Controller* controller) {
  auto on_wire = cast<OnWireMessage>(message->get());
  auto meta = object_pool::Get<rpc::RpcMeta>();
  MaybeOwning<google::protobuf::Message> unpack_to;
  bool accept_msg_in_bytes = false;

  meta->set_correlation_id(on_wire->header.head.serialNo);

  if (server_side_) {
    // For server side we can tell if the message belong to a stream by its
    // method key (as of done in `TryCutMessage`.).
    if (on_wire->stream) {
      meta->set_method_type(rpc::METHOD_TYPE_STREAM);
      // We only support server-side streaming, so there must be only one
      // message in the request.
      meta->set_flags(rpc::MESSAGE_FLAGS_START_OF_STREAM |
                      rpc::MESSAGE_FLAGS_END_OF_STREAM);
    } else {
      meta->set_method_type(rpc::METHOD_TYPE_SINGLE);
    }

    auto&& req_meta = *meta->mutable_request_meta();
    // Copied out of the (packed) header so that nothing below has to refer to
    // the unaligned fields directly.
    auto version = on_wire->header.head.version;
    auto cmd = on_wire->header.head.cmd;
    auto desc = on_wire->method_desc;
    if (FLARE_UNLIKELY(!desc)) {
      FLARE_LOG_WARNING_EVERY_SECOND("Unrecognized QZone version/cmd: {}/{}.",
                                     version, cmd);
      // `on_wire` is destroyed by the assignment below, it must not be touched
      // afterwards.
      *message = std::make_unique<EarlyErrorMessage>(
          meta->correlation_id(), rpc::STATUS_METHOD_NOT_FOUND,
          fmt::format("[{}/{}] (Version/CMD) is not recognized.", version,
                      cmd));
      return true;
    }
    // libstdc++ (pre-GCC5 ABI) does COW here, which hurt performance in our
    // case. (~6.7% down in QPS)
    //
    // req_meta.set_method_name(desc->normalized_method_name);
    req_meta.mutable_method_name()->assign(desc->normalized_method_name.begin(),
                                           desc->normalized_method_name.end());
    unpack_to = MaybeOwning(owning, desc->request_prototype->New());
  } else {
    // For client side, we have no idea if it belongs to a stream. So we believe
    // what we were told.
    auto ctx = cast<ProactiveCallContext>(controller);
    meta->set_method_type(ctx->expecting_stream ? rpc::METHOD_TYPE_STREAM
                                                : rpc::METHOD_TYPE_SINGLE);

    // `meta.flags()` is not set, we're unable to infer it anyway.
    auto&& resp_meta = *meta->mutable_response_meta();

    // We carefully chose same enumerator as `common/rpc/rpc_error_code.proto`,
    // and since `serverResponseInfo` was interpreted as enumerators in
    // `rpc_error_code.proto`, it should fit here.
    //
    // Note that, however, since we allow user-defined status to be used, the
    // value here does not necessarily be a valid enumerator of `rpc::Status`.
    resp_meta.set_status(on_wire->header.head.serverResponseInfo);

    if (ctx->expecting_stream && on_wire->payload.Empty()) {
      // No body at all. This should be an erroneous reply to a streaming call.
      meta->set_flags(rpc::MESSAGE_FLAGS_START_OF_STREAM |
                      rpc::MESSAGE_FLAGS_END_OF_STREAM);
      // `unpack_to` is left `nullptr`.
    } else {
      accept_msg_in_bytes = ctx->accept_response_in_bytes;
      if (FLARE_LIKELY(!accept_msg_in_bytes)) {
        unpack_to = ctx->GetOrCreateResponse();
      }
    }
  }

  // Only allocated once we know we're not taking the early-error path above.
  auto msg = std::make_unique<ProtoMessage>();

  if (FLARE_UNLIKELY(accept_msg_in_bytes)) {
    // The caller asked for raw bytes, hand the payload over without parsing.
    msg->msg_or_buffer = std::move(on_wire->payload);
  } else {
    if (FLARE_LIKELY(unpack_to)) {
      NoncontiguousBufferInputStream nbis(&on_wire->payload);
      if (FLARE_UNLIKELY(!unpack_to->ParseFromZeroCopyStream(&nbis))) {
        FLARE_LOG_WARNING_EVERY_SECOND(
            "Failed to parse message (correlation id {}).",
            meta->correlation_id());
        return false;
      }
      msg->msg_or_buffer = std::move(unpack_to);
    }
  }

  msg->meta = std::move(meta);
  *message = std::move(msg);
  return true;
}

void QZoneProtocol::WriteMessage(const Message& message,
                                 NoncontiguousBuffer* buffer,
                                 Controller* controller) {
  auto msg = cast<ProtoMessage>(message);
  auto&& meta = *msg->meta;
  qzone::QzoneProtocol header;

  FLARE_LOG_ERROR_IF_ONCE(
      !controller->GetTracingContext().empty() ||
          controller->IsTraceForciblySampled(),
      "Passing tracing context is not supported by QZone protocol.");

  header.soh = QzoneProtocolSOH;
  header.head.len = offsetof(qzone::QzoneProtocol, body) +
                    0 /* Payload size, filled later. */ + 1 /* EOT */;
  SafeCopy(msg->GetCorrelationId(), &header.head.serialNo);

  // We only fill `version` / `cmd` for request. It's technically possible for
  // us to remember them in `ProtoMessage` and passing them around so as to also
  // fill them in response packet, but it doesn't make too much sense IMO
  // (assuming without these field we're still able to talk with old servers.).

  // For streaming RPCs, the messages are written exactly the same way as normal
  // RPCs. QZone protocol does not have dedicated fields for indicating
  // streaming calls.

  if (server_side_) {  // FIXME: Branch predictor will likely have a hard time
                       // in predicting this. Maybe we should use two different
                       // class for client side & server side protocol?
    auto&& resp_meta = meta.response_meta();
    header.head.serverResponseFlag = (resp_meta.status() == rpc::STATUS_SUCCESS)
                                         ? qzone::QzoneServerSucc
                                         : qzone::QzoneServerFailed;

    // Using temporary `status` to comfort ASAN. `serverResponseInfo` is not
    // properly aligned.
    decltype(header.head.serverResponseInfo) status;
    SafeCopy(resp_meta.status(), &status);
    header.head.serverResponseInfo = status;
  } else {
    auto ctx = cast<ProactiveCallContext>(controller);
    auto desc = TryGetQzoneMethodKey(ctx->method);
    FLARE_CHECK(
        desc.first,
        "You didn't set option `qzone_service_id` or `qzone_method_id` for "
        "method [{}], which are required for calling it via QZone protocol.",
        meta.request_meta().method_name());
    header.head.version = desc.first;
    header.head.cmd = desc.second;
  }

  // Check-summing is NOT done. It's a bit hard to passing option of whether
  // check-summing is enabled from `ServerDescriptor.options()` to here, and,
  // AFAICT, our old RPC framework didn't check about it.

  // Build buffers.
  NoncontiguousBufferBuilder builder;
  auto header_ptr = builder.Reserve(sizeof(header));
  header.head.len += WriteTo(msg->msg_or_buffer, &builder);
  FLARE_LOG_ERROR_IF_ONCE(
      !msg->attachment.Empty(),
      "Attachment is not supported by QZone protocol. Dropped silently.");
  builder.Append(static_cast<char>(QzoneProtocolEOT));

  header.head.Encode();  // `htonl` stuff.
  memcpy(header_ptr, &header, sizeof(header));

  buffer->Append(builder.DestructiveGet());
}

}  // namespace flare::protobuf
